from elasticsearch import Elasticsearch
import tensorflow_hub as hub
import tensorflow.compat.v1 as tf
import pandas as pd
import numpy as np
import tensorflow_text as text

# Elasticsearch client for the local node at 127.0.0.1:9200.
es = Elasticsearch("http://127.0.0.1:9200")

# Name of the Elasticsearch index holding the documents and their title vectors.
INDEX_NAME = "yiliao_vectors"

# Load the embedding model and build the graph used by embed_text / import_data.
graph = tf.Graph()

# NOTE(review): this block is order-sensitive. Entering `tf.Session(graph=graph)`
# as a context manager presumably also makes `graph` the default graph (TF1
# Session behavior), so the placeholder and model ops below are created in
# `graph` -- confirm before refactoring.
with tf.Session(graph = graph) as session:
    print("Loading pre-trained embeddings")
    # TF Hub model loaded from the local ./model directory.
    embed = hub.load("./model")
    # Input batch of strings; no shape is given, so the fed shape is unconstrained.
    text_ph = tf.placeholder(tf.string)
    # Output tensor that session.run(...) evaluates elsewhere in this file.
    embeddings = embed(text_ph)

    print("Creating tensorflow session…")
    # Rebinds the module-level name `session` to a second session created on the
    # current default graph. The `with` statement still closes the FIRST session
    # on exit, so this second session is the one the functions below use.
    session = tf.Session()
    session.run(tf.global_variables_initializer())
    # Presumably initializes lookup tables the model depends on -- confirm.
    session.run(tf.tables_initializer())

# Import the sample CSV into Elasticsearch, embedding each title into the
# `title_vector` field.
def import_data(csv_path='./yiliao.csv'):
    """Load the sample CSV into Elasticsearch with an embedding for every title.

    Reads the CSV, embeds the whole ``title`` column in one model call,
    (re)creates the index via ``create_es_index()`` and then indexes one
    document per row, using the DataFrame row label as the document id.

    Args:
        csv_path: Path of the CSV to import. It must contain the columns
            ``department``, ``title``, ``ask`` and ``answer``. Defaults to the
            previously hard-coded ``./yiliao.csv``.
    """
    df = pd.read_csv(csv_path)

    # embed_text returns plain Python lists (one per title, in row order). These
    # serialize to JSON without relying on the client's numpy support and match
    # the form do_query uses for query vectors.
    title_vectors = embed_text(df['title'].tolist())

    # Deletes any existing index of the same name before recreating it.
    create_es_index()

    for (index, row), title_vector in zip(df.iterrows(), title_vectors):
        doc = {
            "department": _none_if_missing(row["department"]),
            "title": _none_if_missing(row["title"]),
            "ask": _none_if_missing(row["ask"]),
            "answer": _none_if_missing(row["answer"]),
            "title_vector": title_vector,
        }
        # es.bulk timed out on large inputs, so documents are indexed one by one.
        es.index(index=INDEX_NAME, body=doc, id=index)
        print(index)


def _none_if_missing(value):
    """Return None for a pandas missing value (NaN/None), else ``value`` unchanged.

    Empty CSV cells are read as float NaN, and JSON encoding writes NaN as the
    bare token ``NaN``. That is not valid JSON and is rejected by Elasticsearch.
    """
    return None if pd.isna(value) else value



# Create (or recreate) the Elasticsearch index used by this script.
def create_es_index():
    """(Re)create ``INDEX_NAME`` with this script's settings and mapping.

    Any existing index with that name is deleted first, so each call starts
    from an empty index. The mapping declares only ``title_vector``, as a
    512-dimensional ``dense_vector``. The settings also define an edge-ngram
    analyzer (``ngram_analyzer``), although no mapped field references it.
    """
    edge_ngram_filter = {
        "type": "edge_ngram",
        "min_gram": 2,
        "max_gram": 15,
    }
    ngram_analyzer = {
        "type": "custom",
        "tokenizer": "standard",
        "filter": ["lowercase", "ngram_filter"],
    }
    settings = {
        "index": {"number_of_replicas": 2},
        "analysis": {
            "filter": {"ngram_filter": edge_ngram_filter},
            "analyzer": {"ngram_analyzer": ngram_analyzer},
        },
    }
    mappings = {
        "properties": {
            "title_vector": {"type": "dense_vector", "dims": 512},
        },
    }

    index_present = es.indices.exists(index=INDEX_NAME)
    if index_present:
        print("索引已存在，会自动删除后再重建")
        es.indices.delete(index=INDEX_NAME)

    es.indices.create(
        index=INDEX_NAME,
        body={"settings": settings, "mappings": mappings},
    )
  

# Embed text with the loaded model.
def embed_text(text):
    """Embed a batch of strings and return the vectors as nested Python lists.

    Args:
        text: Sequence of strings fed to the ``text_ph`` placeholder. Note that
            this parameter name shadows the module-level ``text``
            (tensorflow_text) alias inside this function.

    Returns:
        One list of Python numbers per input string.
    """
    matrix = session.run(embeddings, feed_dict={text_ph: text})
    # ndarray.tolist() on the 2-D result converts every row to a Python list.
    return matrix.tolist()


# Interactive loop for manually trying out searches.
def test():
    """Prompt for queries and run a semantic search for each one.

    Returns normally when the user presses Ctrl-C, or when standard input is
    exhausted (Ctrl-D, or the end of piped input).
    """
    while True:
        try:
            query = input("请输入搜索内容: ")
            print(query)
            do_query(query)
        except (KeyboardInterrupt, EOFError):
            # input() raises EOFError at end of input; treat it like Ctrl-C
            # instead of crashing with a traceback.
            return

# Run a semantic (vector) search over the indexed titles.
def do_query(query, size=10):
    """Print the titles most similar to ``query`` and return the raw hits.

    Every document is scored by the cosine similarity between the query
    embedding and its ``title_vector``. The ``+1`` keeps scores non-negative,
    which ``script_score`` requires.

    Args:
        query: Free-text search string.
        size: Maximum number of hits to fetch. The default of 10 matches the
            Elasticsearch default page size the request previously relied on.

    Returns:
        The list of hit dicts from ``response["hits"]["hits"]``.
    """
    query_vector = embed_text([query])[0]

    response = es.search(
        index=INDEX_NAME,
        body={
            "size": size,
            "_source": ["title", "ask", "answer"],
            "query": {
                "script_score": {
                    "query": {"match_all": {}},
                    "script": {
                        "source": "cosineSimilarity(params.queryVector, doc['title_vector'])+1",
                        "params": {"queryVector": query_vector},
                    },
                }
            },
        },
    )

    hits = response["hits"]["hits"]
    print("您搜索的title是：")
    for hit in hits:
        print(hit["_source"]["title"])
    print("\n")
    return hits


if __name__ == '__main__':
    # Step 1 (run once, or whenever the CSV changes): import the sample data
    # into Elasticsearch, embedding each title. Uncomment to rebuild the index.
    # import_data()

    # Step 2: once the data is imported, try searches interactively.
    test()